{
 "cells": [
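  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compound Representation Learning and Property Prediction\n",
    "\n",
    "In this tutorial, we show how to use graph neural network (GNN) models to predict the properties of compounds. Specifically, we demonstrate how to pretrain a model, how to finetune it for a downstream task, and how to run inference with the final model. For more details, see the explanations in \"[info graph](https://github.com/PaddlePaddle/PaddleHelix/apps/pretrained_compound/info_graph/README_cn.md)\" and \"[pretrained GNN](https://github.com/PaddlePaddle/PaddleHelix/apps/pretrained_compound/pretrain_gnns/README_cn.md)\".\n",
    "\n",
    "# Part I: Pretraining\n",
    "\n",
    "In this part, we show how to pretrain a compound GNN model. The pretraining techniques used here build on the pretrained GNN work and include attribute masking, context prediction, and supervised pretraining.\n",
    "For more details, see the files `pretrain_attrmask.py`, `pretrain_contextpred.py`, and `pretrain_supervised.py`."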
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import sys\n",
    "sys.path.insert(0, os.getcwd() + \"/..\")\n",
    "os.chdir(\"../apps/pretrained_compound/pretrain_gnns\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PaddleHelix is a deep learning framework for biocomputing built on top of PaddlePaddle."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[INFO] 2020-12-18 20:18:19,496 [mp_reader.py:   23]:\tujson not install, fail back to use json instead\n"
     ]
    }
   ],
   "source": [
    "import paddle\n",
    "import paddle.fluid as fluid\n",
    "from paddle.fluid.incubate.fleet.collective import fleet\n",
    "from pahelix.datasets import load_zinc_dataset\n",
    "from pahelix.featurizers import PreGNNAttrMaskFeaturizer\n",
    "from pahelix.utils.compound_tools import CompoundConstants\n",
    "from pahelix.model_zoo import PreGNNAttrmaskModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ol/anaconda2/envs/paddle2.0/lib/python3.7/site-packages/ipykernel/ipkernel.py:287: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
      "  and should_run_async(code)\n"
     ]
    }
   ],
   "source": [
    "# switch to paddle static graph mode.\n",
    "paddle.enable_static()"
   ]
  },
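  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building the static graph\n",
    "\n",
    "Typically, we build a static graph with the `Program` and `Executor` provided by Paddle. Here, `model_config` holds the model configuration. `PreGNNAttrmaskModel` is an unsupervised pretraining model: it randomly masks the atom type of some nodes and then tries to predict the masked atom types. We use the Adam optimizer with a learning rate of 0.001.\n",
    "\n",
    "To train on a GPU, uncomment the line `exe = fluid.Executor(fluid.CUDAPlace(0))`. `fluid.CPUPlace()` is used for training on the CPU."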
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ol/anaconda2/envs/paddle2.0/lib/python3.7/site-packages/paddle/fluid/layers/math_op_patch.py:298: UserWarning: /home/ol/jieqiong/repos/PaddleHelix/tutorials/../pahelix/model_zoo/pretrain_gnns_model.py:78\n",
      "The behavior of expression A + B has been unified with elementwise_add(X, Y, axis=-1) from Paddle 2.0. If your code works well in the older versions but crashes in this version, try to use elementwise_add(X, Y, axis=0) instead of A + B. This transitional warning will be dropped in the future.\n",
      "  op_type, op_type, EXPRESSION_MAP[method_name]))\n",
      "/home/ol/anaconda2/envs/paddle2.0/lib/python3.7/site-packages/paddle/fluid/layers/math_op_patch.py:298: UserWarning: /home/ol/jieqiong/repos/PaddleHelix/tutorials/../pahelix/model_zoo/pretrain_gnns_model.py:98\n",
      "The behavior of expression A + B has been unified with elementwise_add(X, Y, axis=-1) from Paddle 2.0. If your code works well in the older versions but crashes in this version, try to use elementwise_add(X, Y, axis=0) instead of A + B. This transitional warning will be dropped in the future.\n",
      "  op_type, op_type, EXPRESSION_MAP[method_name]))\n",
      "/home/ol/anaconda2/envs/paddle2.0/lib/python3.7/site-packages/paddle/fluid/layers/math_op_patch.py:298: UserWarning: /home/ol/jieqiong/repos/PaddleHelix/tutorials/../pahelix/networks/gnn_block.py:194\n",
      "The behavior of expression A + B has been unified with elementwise_add(X, Y, axis=-1) from Paddle 2.0. If your code works well in the older versions but crashes in this version, try to use elementwise_add(X, Y, axis=0) instead of A + B. This transitional warning will be dropped in the future.\n",
      "  op_type, op_type, EXPRESSION_MAP[method_name]))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "var mean_0.tmp_0 : LOD_TENSOR.shape(1,).dtype(FP32).stop_gradient(False)\n"
     ]
    }
   ],
   "source": [
    "model_config = {\n",
    "    \"dropout_rate\": 0.5,# dropout rate\n",
    "    \"gnn_type\": \"gin\",  # other choices like \"gat\", \"gcn\".\n",
    "    \"layer_num\": 5,     # the number of gnn layers.\n",
    "}\n",
    "train_prog = fluid.Program()\n",
    "startup_prog = fluid.Program()\n",
    "with fluid.program_guard(train_prog, startup_prog):\n",
    "    with fluid.unique_name.guard():\n",
    "        model = PreGNNAttrmaskModel(model_config=model_config)\n",
    "        model.forward()\n",
    "        opt = fluid.optimizer.Adam(learning_rate=0.001)\n",
    "        opt.minimize(model.loss)\n",
    "\n",
    "exe = fluid.Executor(fluid.CPUPlace())\n",
    "# exe = fluid.Executor(fluid.CUDAPlace(0))\n",
    "exe.run(startup_prog)\n",
    "print(model.loss)"
   ]
  },
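  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The masking idea behind `PreGNNAttrmaskModel` can be illustrated without Paddle. The following is a minimal NumPy sketch, not the PaddleHelix implementation: the helper `mask_atom_types`, the mask token id `0`, and the example atom-type ids are all hypothetical and chosen only for illustration. It hides a fraction of the atom types and keeps the original values as prediction targets.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def mask_atom_types(atom_types, mask_ratio, mask_token, rng):\n",
    "    # Randomly replace a fraction of atom types with a mask token.\n",
    "    # Returns the masked copy, the masked node indices, and the\n",
    "    # original types at those indices (the prediction targets).\n",
    "    atom_types = np.asarray(atom_types)\n",
    "    num_nodes = len(atom_types)\n",
    "    # Mask at least one node so every molecule contributes a target.\n",
    "    num_masked = max(1, int(round(num_nodes * mask_ratio)))\n",
    "    masked_idx = np.sort(rng.choice(num_nodes, size=num_masked, replace=False))\n",
    "    targets = atom_types[masked_idx]\n",
    "    masked = atom_types.copy()\n",
    "    masked[masked_idx] = mask_token\n",
    "    return masked, masked_idx, targets\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "# Hypothetical atom-type ids for a 10-atom molecule (assumption).\n",
    "atom_types = np.array([6, 6, 8, 7, 6, 6, 6, 8, 6, 16])\n",
    "masked, masked_idx, targets = mask_atom_types(\n",
    "    atom_types, mask_ratio=0.15, mask_token=0, rng=rng)\n",
    "print(masked, masked_idx, targets)\n",
    "```\n",
    "\n",
    "A model trained with this objective would receive `masked` as input and be asked to recover `targets` at the positions in `masked_idx`."
   ]
  },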
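  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset loading and feature extraction\n",
    "\n",
    "We first use `wget` to download a small test dataset. If `wget` is not available on your machine, you can copy the link below into your browser to download the data. Note that you then need to move the archive to the path \"../apps/pretrained_compound/pretrain_gnns/\"."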
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-12-18 20:18:27--  https://baidu-nlp.bj.bcebos.com/PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz\n",
      "Connecting to 172.19.56.199:3128... connected.\n",
      "WARNING: certificate common name “*.bcebos.com” doesn’t match requested host name “baidu-nlp.bj.bcebos.com”.\n",
      "Proxy request sent, awaiting response... 200 OK\n",
      "Length: 609563 (595K) [application/gzip]\n",
      "Saving to: “PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz.1”\n",
      "\n",
      "100%[======================================>] 609,563      266K/s   in 2.2s    \n",
      "\n",
      "2020-12-18 20:18:32 (266 KB/s) - “PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz.1” saved [609563/609563]\n",
      "\n",
      "tox21  zinc_standard_agent\n"
     ]
    }
   ],
   "source": [
    "### Download a toy dataset for demonstration:\n",
    "!wget \"https://baidu-nlp.bj.bcebos.com/PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz\" --no-check-certificate\n",
    "!tar -zxf \"PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz\"\n",
    "!ls \"./chem_dataset_small\"\n",
    "### Download the full dataset as you want:\n",
    "# !wget \"http://snap.stanford.edu/gnn-pretrain/data/chem_dataset.zip\" --no-check-certificate\n",
    "# !unzip \"chem_dataset.zip\"\n",
    "# !ls \"./chem_dataset\""
   ]
  },
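  {
   "cell_type": "markdown",
   "source": [
    "Use `PreGNNAttrMaskFeaturizer` together with the model `PreGNNAttrmaskModel`. It inherits from the superclass `Featurizer`, which is used for feature extraction. A `Featurizer` has two functions: `gen_features` converts a raw SMILES string into graph data, and `collate_fn` aggregates a sublist of graph data into a batch. Here we use the Zinc dataset for pretraining."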
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dataset num: 1000\n"
     ]
    }
   ],
   "source": [
    "featurizer = PreGNNAttrMaskFeaturizer(\n",
    "        model.graph_wrapper, \n",
    "        atom_type_num=len(CompoundConstants.atom_num_list),\n",
    "        mask_ratio=0.15)\n",
    "### Load only the first 1000 samples of the toy dataset to speed things up\n",
    "dataset = load_zinc_dataset(\"./chem_dataset_small/zinc_standard_agent/raw\", featurizer=featurizer)\n",
    "dataset = dataset[:1000]\n",
    "### Load the full dataset:\n",
    "# dataset = load_zinc_dataset(\"./chem_dataset/zinc_standard_agent/raw\", featurizer=featurizer)\n",
    "print(\"dataset num: %s\" % (len(dataset)))"
   ]
  },
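  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the role of `collate_fn` more concrete, here is a minimal NumPy sketch of how several small graphs can be merged into one batch. It is not the PaddleHelix code: the function `collate_graphs` and the dictionary keys `node_feat`, `edges`, and `graph_id` are assumptions made for this illustration. The key step is shifting each graph's edge indices by the number of nodes that precede it.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def collate_graphs(graphs):\n",
    "    # Each graph is a dict with 'node_feat' (N, F) and 'edges' (E, 2).\n",
    "    node_feats, edges, graph_ids = [], [], []\n",
    "    offset = 0\n",
    "    for gid, g in enumerate(graphs):\n",
    "        num_nodes = g['node_feat'].shape[0]\n",
    "        node_feats.append(g['node_feat'])\n",
    "        # Shift node indices so edges point into the merged node array.\n",
    "        edges.append(g['edges'] + offset)\n",
    "        # Remember which graph every node came from.\n",
    "        graph_ids.append(np.full(num_nodes, gid, dtype=np.int64))\n",
    "        offset += num_nodes\n",
    "    return {\n",
    "        'node_feat': np.concatenate(node_feats, 0),\n",
    "        'edges': np.concatenate(edges, 0),\n",
    "        'graph_id': np.concatenate(graph_ids, 0),\n",
    "    }\n",
    "\n",
    "# Two toy graphs with 3 and 2 nodes (hypothetical data).\n",
    "g1 = {'node_feat': np.ones((3, 2)), 'edges': np.array([[0, 1], [1, 2]])}\n",
    "g2 = {'node_feat': np.zeros((2, 2)), 'edges': np.array([[0, 1]])}\n",
    "batch = collate_graphs([g1, g2])\n",
    "print(batch['edges'])\n",
    "```\n",
    "\n",
    "The `graph_id` array lets a graph-level pooling step tell which nodes belong to which molecule after merging."
   ]
  },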
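  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start training\n",
    "\n",
    "Now we train the attrmask model. We train for only two epochs as a demonstration, and data loading is sped up with 4 `workers`. We then save the pretrained model to \"./model/pretrain_attrmask\" as the initial model for downstream tasks."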
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:0 train/loss:1.042354\n",
      "epoch:1 train/loss:0.85617626\n"
     ]
    }
   ],
   "source": [
    "def train(exe, train_prog, model, dataset, featurizer):\n",
    "    data_gen = dataset.iter_batch(\n",
    "            batch_size=256, num_workers=4, shuffle=True, collate_fn=featurizer.collate_fn)\n",
    "    list_loss = []\n",
    "    for batch_id, feed_dict in enumerate(data_gen):\n",
    "        train_loss, = exe.run(train_prog, \n",
    "                feed=feed_dict, fetch_list=[model.loss], return_numpy=False)\n",
    "        list_loss.append(np.array(train_loss).mean())\n",
    "    return np.mean(list_loss)\n",
    "\n",
    "for epoch_id in range(2):\n",
    "    train_loss = train(exe, train_prog, model, dataset, featurizer)\n",
    "    print(\"epoch:%d train/loss:%s\" % (epoch_id, train_loss))\n",
    "fluid.io.save_params(exe, './model/pretrain_attrmask', train_prog)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That is all for model pretraining. You can adjust the parameters above to suit your needs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part II: Finetuning for downstream tasks\n",
    "\n",
    "Next, we show how to finetune the pretrained model to adapt it to a downstream task.\n",
    "\n",
    "For more details, see `finetune.py`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pahelix.utils.paddle_utils import load_partial_params\n",
    "from pahelix.utils.splitters import \\\n",
    "    RandomSplitter, IndexSplitter, ScaffoldSplitter, RandomScaffoldSplitter\n",
    "from pahelix.datasets import *\n",
    "\n",
    "from model import DownstreamModel\n",
    "from featurizer import DownstreamFeaturizer\n",
    "from utils import calc_rocauc_score"
   ]
  },
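  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Downstream datasets are usually small and target different tasks. For example, the BBBP dataset is used to predict the blood-brain barrier permeability of compounds, and the Tox21 dataset is used to predict compound toxicity. Here we use the Tox21 dataset for demonstration."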
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']\n"
     ]
    }
   ],
   "source": [
    "task_names = get_default_tox21_task_names()\n",
    "print(task_names)"
   ]
  },
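  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building the static graph\n",
    "\n",
    "We build a static graph model in the same way as before. Note that the model architecture settings here must match those of the pretrained model; otherwise, loading the model will fail. `DownstreamModel` is a supervised GNN model for the prediction tasks defined in `task_names` above.\n",
    "\n",
    "We use `train_prog` and `test_prog` to hold the static graphs for training and testing. They share the same network architecture, but some operators, such as `Dropout` and `BatchNorm`, behave differently in the two programs."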
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ol/anaconda2/envs/paddle2.0/lib/python3.7/site-packages/paddle/fluid/layers/math_op_patch.py:298: UserWarning: /home/ol/jieqiong/repos/PaddleHelix/apps/pretrained_compound/pretrain_gnns/model.py:90\n",
      "The behavior of expression A * B has been unified with elementwise_mul(X, Y, axis=-1) from Paddle 2.0. If your code works well in the older versions but crashes in this version, try to use elementwise_mul(X, Y, axis=0) instead of A * B. This transitional warning will be dropped in the future.\n",
      "  op_type, op_type, EXPRESSION_MAP[method_name]))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_config = {\n",
    "    \"dropout_rate\": 0.5,# dropout rate\n",
    "    \"gnn_type\": \"gin\",  # other choices like \"gat\", \"gcn\".\n",
    "    \"layer_num\": 5,     # the number of gnn layers.\n",
    "    \"num_tasks\": len(task_names), # number of targets to predict for the downstream task.\n",
    "}\n",
    "train_prog = fluid.Program()\n",
    "test_prog = fluid.Program()\n",
    "startup_prog = fluid.Program()\n",
    "with fluid.program_guard(train_prog, startup_prog):\n",
    "    with fluid.unique_name.guard():\n",
    "        model = DownstreamModel(model_config=model_config)\n",
    "        model.train()\n",
    "        opt = fluid.optimizer.Adam(learning_rate=0.001)\n",
    "        opt.minimize(model.loss)\n",
    "with fluid.program_guard(test_prog, fluid.Program()):\n",
    "    with fluid.unique_name.guard():\n",
    "        model = DownstreamModel(model_config=model_config)\n",
    "        model.train(is_test=True)\n",
    "\n",
    "exe = fluid.Executor(fluid.CPUPlace())\n",
    "# exe = fluid.Executor(fluid.CUDAPlace(0))\n",
    "exe.run(startup_prog)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading the pretrained model\n",
    "\n",
    "Load the model obtained in the pretraining stage. Here we load the \"pretrain_attrmask\" model as an example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Load parameters from ./model/pretrain_attrmask.\n"
     ]
    }
   ],
   "source": [
    "load_partial_params(exe, './model/pretrain_attrmask', train_prog)"
   ]
  },
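  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data loading and feature extraction\n",
    "\n",
    "Use `DownstreamFeaturizer` together with `DownstreamModel`. It inherits from the superclass `Featurizer`, which is used for feature extraction. A `Featurizer` has two functions: `gen_features` converts a raw SMILES string into a single graph data object, and `collate_fn` aggregates a sublist of graph data into a batch.\n",
    "\n",
    "Tox21 serves as the downstream dataset, and we use `ScaffoldSplitter` to split it into training, validation, and test sets. `ScaffoldSplitter` first sorts the compounds by their Bemis-Murcko scaffold and then, going from front to back, assigns the fraction given by `frac_train` to the training set, the fraction given by `frac_valid` to the validation set, and the rest to the test set. `ScaffoldSplitter` gives a better measure of how well the model generalizes to out-of-distribution samples. Other splitters can also be used here, such as `RandomSplitter`, `RandomScaffoldSplitter`, and `IndexSplitter`."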
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "RDKit WARNING: [20:23:14] WARNING: not removing hydrogen atom without neighbors\n",
      "RDKit WARNING: [20:23:29] WARNING: not removing hydrogen atom without neighbors\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train/Valid/Test num: 6264/783/784\n"
     ]
    }
   ],
   "source": [
    "featurizer = DownstreamFeaturizer(model.graph_wrapper)\n",
    "### Load the toy dataset:\n",
    "dataset = load_tox21_dataset(\n",
    "        \"./chem_dataset_small/tox21/raw\", task_names, featurizer=featurizer)\n",
    "### Load the full dataset:\n",
    "# dataset = load_tox21_dataset(\n",
    "#         \"./chem_dataset/tox21/raw\", task_names, featurizer=featurizer)\n",
    "\n",
    "# splitter = RandomSplitter()\n",
    "splitter = ScaffoldSplitter()\n",
    "train_dataset, valid_dataset, test_dataset = splitter.split(\n",
    "        dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1)\n",
    "print(\"Train/Valid/Test num: %s/%s/%s\" % (\n",
    "        len(train_dataset), len(valid_dataset), len(test_dataset)))"
   ]
  },
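  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea of a scaffold-based split can be sketched in plain Python. This is a simplified illustration and not the `ScaffoldSplitter` source: the function `scaffold_split` is hypothetical, the scaffold strings are assumed to be precomputed (in practice they could come from RDKit's Murcko scaffold utilities), and placing larger scaffold groups first is one common convention that is assumed here. The point it shows is that all compounds sharing a scaffold land in the same split.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "def scaffold_split(scaffolds, frac_train=0.8, frac_valid=0.1):\n",
    "    # Group sample indices by their (precomputed) scaffold string.\n",
    "    groups = defaultdict(list)\n",
    "    for idx, scaffold in enumerate(scaffolds):\n",
    "        groups[scaffold].append(idx)\n",
    "    # Largest scaffold groups first; ties broken by first index.\n",
    "    ordered = sorted(groups.values(), key=lambda g: (-len(g), g[0]))\n",
    "    n = len(scaffolds)\n",
    "    train_cut = frac_train * n\n",
    "    valid_cut = (frac_train + frac_valid) * n\n",
    "    train, valid, test = [], [], []\n",
    "    for group in ordered:\n",
    "        # A whole group goes to the first split that still has room.\n",
    "        if len(train) + len(group) <= train_cut:\n",
    "            train.extend(group)\n",
    "        elif len(train) + len(valid) + len(group) <= valid_cut:\n",
    "            valid.extend(group)\n",
    "        else:\n",
    "            test.extend(group)\n",
    "    return train, valid, test\n",
    "\n",
    "# Ten toy compounds with made-up scaffold labels.\n",
    "scaffolds = ['A'] * 5 + ['B'] * 3 + ['C', 'D']\n",
    "train, valid, test = scaffold_split(scaffolds)\n",
    "print(train, valid, test)\n",
    "```\n",
    "\n",
    "Because rare scaffolds end up in the validation and test sets, the evaluation compounds are structurally different from most of the training compounds, which is why this kind of split is a harder test than a random one."
   ]
  },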
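  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start training\n",
    "\n",
    "For demonstration purposes, we finetune the model initialized from attrmask for only 4 epochs. Since each downstream task consists of multiple sub-tasks, we compute the ROC-AUC of each sub-task separately and then take their mean as the final evaluation metric."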
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ol/anaconda2/envs/paddle2.0/lib/python3.7/site-packages/ipykernel/ipkernel.py:287: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
      "  and should_run_async(code)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Valid ratio: 0.7603235\n",
      "Task evaluated: 12/12\n",
      "Valid ratio: 0.7513818\n",
      "Task evaluated: 12/12\n",
      "epoch:0 train/loss:0.2143532\n",
      "epoch:0 val/auc:0.6834385071324883\n",
      "epoch:0 test/auc:0.6735986534078915\n",
      "Valid ratio: 0.7603235\n",
      "Task evaluated: 12/12\n",
      "Valid ratio: 0.7513818\n",
      "Task evaluated: 12/12\n",
      "epoch:1 train/loss:0.20883079\n",
      "epoch:1 val/auc:0.727357536520297\n",
      "epoch:1 test/auc:0.6860799064889367\n",
      "Valid ratio: 0.7603235\n",
      "Task evaluated: 12/12\n",
      "Valid ratio: 0.7513818\n",
      "Task evaluated: 12/12\n",
      "epoch:2 train/loss:0.20595507\n",
      "epoch:2 val/auc:0.691802226943295\n",
      "epoch:2 test/auc:0.6565749600201853\n",
      "Valid ratio: 0.7603235\n",
      "Task evaluated: 12/12\n",
      "Valid ratio: 0.7513818\n",
      "Task evaluated: 12/12\n",
      "epoch:3 train/loss:0.20529544\n",
      "epoch:3 val/auc:0.7298949686528418\n",
      "epoch:3 test/auc:0.6967547132444416\n"
     ]
    }
   ],
   "source": [
    "def train(exe, train_prog, model, train_dataset, featurizer):\n",
    "    data_gen = train_dataset.iter_batch(\n",
    "        batch_size=64, num_workers=4, shuffle=True, collate_fn=featurizer.collate_fn)\n",
    "    list_loss = []\n",
    "    for batch_id, feed_dict in enumerate(data_gen):\n",
    "        train_loss, = exe.run(train_prog, feed=feed_dict, fetch_list=[model.loss], return_numpy=False)\n",
    "        list_loss.append(np.array(train_loss).mean())\n",
    "    return np.mean(list_loss)\n",
    "\n",
    "def evaluate(exe, test_prog, model, test_dataset, featurizer):\n",
    "    \"\"\"\n",
    "    In the dataset, a proportion of labels are blank. So we use a `valid` tensor\n",
    "    to help eliminate these blank labels in both training and evaluation phase.\n",
    "    \n",
    "    Returns:\n",
    "        the average roc-auc of all sub-tasks.\n",
    "    \"\"\"\n",
    "    data_gen = test_dataset.iter_batch(\n",
    "        batch_size=64, num_workers=4, shuffle=False, collate_fn=featurizer.collate_fn)\n",
    "    total_pred = []\n",
    "    total_label = []\n",
    "    total_valid = []\n",
    "    for batch_id, feed_dict in enumerate(data_gen):\n",
    "        pred, = exe.run(test_prog, feed=feed_dict, fetch_list=[model.pred], return_numpy=False)\n",
    "        total_pred.append(np.array(pred))\n",
    "        total_label.append(feed_dict['finetune_label'])\n",
    "        total_valid.append(feed_dict['valid'])\n",
    "    total_pred = np.concatenate(total_pred, 0)\n",
    "    total_label = np.concatenate(total_label, 0)\n",
    "    total_valid = np.concatenate(total_valid, 0)\n",
    "    return calc_rocauc_score(total_label, total_pred, total_valid)\n",
    "\n",
    "for epoch_id in range(4):\n",
    "    train_loss = train(exe, train_prog, model, train_dataset, featurizer)\n",
    "    val_auc = evaluate(exe, test_prog, model, valid_dataset, featurizer)\n",
    "    test_auc = evaluate(exe, test_prog, model, test_dataset, featurizer)\n",
    "    print(\"epoch:%s train/loss:%s\" % (epoch_id, train_loss))\n",
    "    print(\"epoch:%s val/auc:%s\" % (epoch_id, val_auc))\n",
    "    print(\"epoch:%s test/auc:%s\" % (epoch_id, test_auc))\n",
    "fluid.io.save_params(exe, './model/tox21', train_prog)"
   ]
  },
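  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The evaluation metric, a mean ROC-AUC over sub-tasks that ignores blank labels, can be sketched with NumPy alone. This is an illustrative stand-in and not the implementation of `calc_rocauc_score`: the helpers `binary_auc` and `masked_mean_auc` are hypothetical, and skipping sub-tasks that lack one of the two classes is an assumption. The AUC is computed from its pairwise definition, the probability that a positive sample is scored above a negative one.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def binary_auc(labels, scores):\n",
    "    # AUC as the probability that a positive outranks a negative\n",
    "    # (ties count as one half); O(P * N), fine for a small example.\n",
    "    pos = scores[labels == 1]\n",
    "    neg = scores[labels == 0]\n",
    "    diff = pos[:, None] - neg[None, :]\n",
    "    return ((diff > 0).sum() + 0.5 * (diff == 0).sum()) / (len(pos) * len(neg))\n",
    "\n",
    "def masked_mean_auc(labels, preds, valid):\n",
    "    # labels, preds, valid: arrays of shape (num_samples, num_tasks).\n",
    "    # valid[i, j] == 1 marks a label that is actually present.\n",
    "    aucs = []\n",
    "    for j in range(labels.shape[1]):\n",
    "        keep = valid[:, j] == 1\n",
    "        y, p = labels[keep, j], preds[keep, j]\n",
    "        # AUC is undefined unless both classes are present.\n",
    "        if (y == 1).any() and (y == 0).any():\n",
    "            aucs.append(binary_auc(y, p))\n",
    "    return float(np.mean(aucs))\n",
    "\n",
    "# Toy data: 4 samples, 2 sub-tasks, one blank label in the second task.\n",
    "labels = np.array([[1, 0], [0, 1], [1, 0], [0, 0]])\n",
    "preds = np.array([[0.9, 0.2], [0.1, 0.8], [0.4, 0.6], [0.5, 0.3]])\n",
    "valid = np.array([[1, 1], [1, 1], [1, 0], [1, 1]])\n",
    "mean_auc = masked_mean_auc(labels, preds, valid)\n",
    "print(mean_auc)  # 0.875\n",
    "```\n",
    "\n",
    "In this toy case the first sub-task scores 0.75 and the second, after dropping its blank label, scores 1.0, so the mean is 0.875."
   ]
  },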
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part III: Downstream inference\n",
    "In this part, we briefly show how to use the trained downstream model to make predictions for a given SMILES string."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building the static graph\n",
    "This part is essentially the same as the graph construction in Part II."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ol/anaconda2/envs/paddle2.0/lib/python3.7/site-packages/ipykernel/ipkernel.py:287: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
      "  and should_run_async(code)\n",
      "/home/ol/anaconda2/envs/paddle2.0/lib/python3.7/site-packages/paddle/fluid/layers/math_op_patch.py:298: UserWarning: /home/ol/jieqiong/repos/PaddleHelix/tutorials/../pahelix/model_zoo/pretrain_gnns_model.py:78\n",
      "The behavior of expression A + B has been unified with elementwise_add(X, Y, axis=-1) from Paddle 2.0. If your code works well in the older versions but crashes in this version, try to use elementwise_add(X, Y, axis=0) instead of A + B. This transitional warning will be dropped in the future.\n",
      "  op_type, op_type, EXPRESSION_MAP[method_name]))\n",
      "/home/ol/anaconda2/envs/paddle2.0/lib/python3.7/site-packages/paddle/fluid/layers/math_op_patch.py:298: UserWarning: /home/ol/jieqiong/repos/PaddleHelix/tutorials/../pahelix/model_zoo/pretrain_gnns_model.py:98\n",
      "The behavior of expression A + B has been unified with elementwise_add(X, Y, axis=-1) from Paddle 2.0. If your code works well in the older versions but crashes in this version, try to use elementwise_add(X, Y, axis=0) instead of A + B. This transitional warning will be dropped in the future.\n",
      "  op_type, op_type, EXPRESSION_MAP[method_name]))\n",
      "/home/ol/anaconda2/envs/paddle2.0/lib/python3.7/site-packages/paddle/fluid/layers/math_op_patch.py:298: UserWarning: /home/ol/jieqiong/repos/PaddleHelix/tutorials/../pahelix/networks/gnn_block.py:194\n",
      "The behavior of expression A + B has been unified with elementwise_add(X, Y, axis=-1) from Paddle 2.0. If your code works well in the older versions but crashes in this version, try to use elementwise_add(X, Y, axis=0) instead of A + B. This transitional warning will be dropped in the future.\n",
      "  op_type, op_type, EXPRESSION_MAP[method_name]))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_config = {\n",
    "    \"dropout_rate\": 0.5,# dropout rate\n",
    "    \"gnn_type\": \"gin\",  # other choices like \"gat\", \"gcn\".\n",
    "    \"layer_num\": 5,     # the number of gnn layers.\n",
    "    \"num_tasks\": len(task_names), # number of targets to predict for the downstream task.\n",
    "}\n",
    "inference_prog = fluid.Program()\n",
    "startup_prog = fluid.Program()\n",
    "with fluid.program_guard(inference_prog, startup_prog):\n",
    "    with fluid.unique_name.guard():\n",
    "        model = DownstreamModel(model_config=model_config)\n",
    "        model.inference()\n",
    "\n",
    "exe = fluid.Executor(fluid.CPUPlace())\n",
    "# exe = fluid.Executor(fluid.CUDAPlace(0))\n",
    "exe.run(startup_prog)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading the trained downstream model\n",
    "Load the downstream model trained in Part II."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Load parameters from ./model/tox21.\n"
     ]
    }
   ],
   "source": [
    "load_partial_params(exe, './model/tox21', inference_prog)"
   ]
  },
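  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running inference\n",
    "Make predictions for a given SMILES string. We directly call the `gen_features` and `collate_fn` functions of `DownstreamFeaturizer` to convert the raw SMILES string into model input.\n",
    "\n",
    "Taking Tox21 as an example, our downstream model outputs predictions for the 12 sub-tasks in Tox21."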
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SMILES:O=C1c2ccccc2C(=O)C1c1ccc2cc(S(=O)(=O)[O-])cc(S(=O)(=O)[O-])c2n1\n",
      "Predictions:\n",
      "  NR-AR:\t0.03680832\n",
      "  NR-AR-LBD:\t0.026945854\n",
      "  NR-AhR:\t0.3054564\n",
      "  NR-Aromatase:\t0.067290716\n",
      "  NR-ER:\t0.28590235\n",
      "  NR-ER-LBD:\t0.07384003\n",
      "  NR-PPAR-gamma:\t0.02805905\n",
      "  SR-ARE:\t0.31503102\n",
      "  SR-ATAD5:\t0.056458935\n",
      "  SR-HSE:\t0.050735578\n",
      "  SR-MMP:\t0.33103985\n",
      "  SR-p53:\t0.09834782\n"
     ]
    }
   ],
   "source": [
    "SMILES=\"O=C1c2ccccc2C(=O)C1c1ccc2cc(S(=O)(=O)[O-])cc(S(=O)(=O)[O-])c2n1\"\n",
    "featurizer = DownstreamFeaturizer(model.graph_wrapper, is_inference=True)\n",
    "feed_dict = featurizer.collate_fn([featurizer.gen_features({'smiles': SMILES})])\n",
    "pred, = exe.run(inference_prog, feed=feed_dict, fetch_list=[model.pred], return_numpy=False)\n",
    "probs = np.array(pred)[0]\n",
    "print('SMILES:%s' % SMILES)\n",
    "print('Predictions:')\n",
    "for name, prob in zip(task_names, probs):\n",
    "    print(\"  %s:\\t%s\" % (name, prob))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}